{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "hypernerf_ap_ds_figure",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "EOnw38FUTbLI"
      },
      "source": [
        "# @title Imports\n",
        "\n",
        "from dataclasses import dataclass\n",
        "from pprint import pprint\n",
        "from typing import Any, List, Callable, Dict, Sequence\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import jax\n",
        "from jax.config import config as jax_config\n",
        "import jax.numpy as jnp\n",
        "from jax import grad, jit, vmap\n",
        "from jax import random\n",
        "\n",
        "import flax\n",
        "import flax.linen as nn\n",
        "from flax import jax_utils\n",
        "from flax import optim\n",
        "from flax.metrics import tensorboard\n",
        "from flax.training import checkpoints\n",
        "\n",
        "from absl import logging\n",
        "\n",
        "# Monkey patch logging.\n",
        "def myprint(msg, *args, **kwargs):\n",
        " print(msg % args)\n",
        "\n",
        "logging.info = myprint \n",
        "logging.warn = myprint\n",
        "\n",
        "\n",
        "import gin\n",
        "# Put gin into its interactive mode before any config is parsed.\n",
        "# NOTE(review): presumably needed so modules can be re-imported by `reload()`\n",
        "# without gin complaining about re-registered configurables -- confirm.\n",
        "gin.enter_interactive_mode()\n",
        "\n",
        "\n",
        "def scatter_points(points, **kwargs):\n",
        "  \"\"\"Builds a Plotly 3D marker trace from an array of points.\n",
        "\n",
        "  Args:\n",
        "    points: array indexed as `points[:, k]`; columns 0, 1 and 2 are used as\n",
        "      the x, y and z coordinates.\n",
        "    **kwargs: forwarded unchanged to `go.Scatter3d`.\n",
        "\n",
        "  Returns:\n",
        "    The `go.Scatter3d` trace drawn in 'markers' mode.\n",
        "  \"\"\"\n",
        "  # NOTE(review): `go` is not imported in this cell (presumably\n",
        "  # `plotly.graph_objects`); it must be in scope before this is called.\n",
        "  xs, ys, zs = (points[:, axis] for axis in range(3))\n",
        "  return go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', **kwargs)\n",
        "\n",
        "\n",
        "from IPython.core.display import display, HTML, Latex\n",
        "\n",
        "\n",
        "def Markdown(text):\n",
        "  \"\"\"Renders `text` as Markdown in the notebook output.\n",
        "\n",
        "  Args:\n",
        "    text: the raw Markdown source to display.\n",
        "\n",
        "  Returns:\n",
        "    None; the text is displayed as a side effect.\n",
        "  \"\"\"\n",
        "  # Imported locally so the helper does not depend on a later cell binding\n",
        "  # the `IPython` name, and goes through the public display API instead of\n",
        "  # the private `_display_mimetype` helper.\n",
        "  from IPython.display import display_markdown\n",
        "  display_markdown(text, raw=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wJvpDubUf9Nr"
      },
      "source": [
        "# Show the devices JAX can see (sanity check of the runtime).\n",
        "print(jax.devices())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m_ngWEcyCUQR"
      },
      "source": [
        "# @title Utilities\n",
        "import contextlib\n",
        "\n",
        "\n",
        "@contextlib.contextmanager\n",
        "def plot_to_array(height, width, rows=1, cols=1, dpi=100, no_axis=False):\n",
        "  \"\"\"A context manager that plots to a numpy array.\n",
        "\n",
        "  When the context manager exits the output array will be populated with an\n",
        "  image of the plot. If the body of the `with` block raises, the figure is\n",
        "  still closed, the exception propagates and the output array is left\n",
        "  unpopulated.\n",
        "\n",
        "  Usage:\n",
        "      ```\n",
        "      with plot_to_array(480, 640, 2, 2) as (fig, axes, out_image):\n",
        "          axes[0][0].plot(...)\n",
        "      ```\n",
        "  Args:\n",
        "      height: the height of the canvas\n",
        "      width: the width of the canvas\n",
        "      rows: the number of axis rows\n",
        "      cols: the number of axis columns\n",
        "      dpi: the DPI to render at\n",
        "      no_axis: if True will hide the axes of the plot\n",
        "\n",
        "  Yields:\n",
        "    A 3-tuple of: a pyplot Figure, array of Axes, and the output np.ndarray.\n",
        "  \"\"\"\n",
        "  # Imported locally so the helper also works when it is called before the\n",
        "  # cell that imports pyplot has been run.\n",
        "  from matplotlib import pyplot as plt\n",
        "\n",
        "  out_array = np.empty((height, width, 3), dtype=np.uint8)\n",
        "  fig, axes = plt.subplots(\n",
        "      rows, cols, figsize=(width / dpi, height / dpi), dpi=dpi)\n",
        "  try:\n",
        "    if no_axis:\n",
        "      for ax in fig.axes:\n",
        "        ax.margins(0, 0)\n",
        "        ax.axis('off')\n",
        "        ax.get_xaxis().set_visible(False)\n",
        "        ax.get_yaxis().set_visible(False)\n",
        "\n",
        "    yield fig, axes, out_array\n",
        "\n",
        "    # If we haven't already shown or saved the plot, then we need to\n",
        "    # draw the figure first...\n",
        "    fig.tight_layout(pad=0)\n",
        "    fig.canvas.draw()\n",
        "\n",
        "    # Now we can save it to a numpy array. `tostring_rgb` is deprecated in\n",
        "    # recent Matplotlib releases, so read the RGBA buffer and drop alpha.\n",
        "    data = np.asarray(fig.canvas.buffer_rgba())[..., :3]\n",
        "    np.copyto(out_array, data)\n",
        "  finally:\n",
        "    # Close this specific figure (not whichever one is current), also when\n",
        "    # the caller's plotting code raised.\n",
        "    plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JIL7vtr9orwb"
      },
      "source": [
        "# Root JAX PRNG key of the notebook; later cells split or reuse `rng`.\n",
        "rng = random.PRNGKey(20200823)\n",
        "# Shift the numpy random seed by host_id() to shuffle data loaded by different\n",
        "# hosts.\n",
        "np.random.seed(20201473 + jax.host_id())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SnjlLfRyRXaZ"
      },
      "source": [
        "# Dataset location (Colab form field). The existence check is disabled.\n",
        "dataset_name = '' # @param {type:\"string\"}\n",
        "data_dir = gpath.GPath(dataset_name) \n",
        "print('data_dir: ', data_dir)\n",
        "# assert data_dir.exists()\n",
        "\n",
        "# Experiment directory (Colab form field). Note that `exp_dir` is rebound\n",
        "# from the form string to a GPath; it must exist.\n",
        "exp_dir = '' # @param {type:\"string\"}\n",
        "exp_dir = gpath.GPath(exp_dir)\n",
        "print('exp_dir: ', exp_dir)\n",
        "assert exp_dir.exists()\n",
        "\n",
        "# The gin config is read from the experiment directory.\n",
        "config_path = exp_dir / 'config.gin'\n",
        "print('config_path', config_path)\n",
        "assert config_path.exists()\n",
        "\n",
        "# Checkpoints are restored from this sub-directory by the model-loading cell.\n",
        "checkpoint_dir = exp_dir / 'checkpoints'\n",
        "print('checkpoint_dir: ', checkpoint_dir)\n",
        "assert checkpoint_dir.exists()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Brw2WgO7pj5d"
      },
      "source": [
        "# @title Load configuration.\n",
        "reload()\n",
        "\n",
        "import IPython\n",
        "from IPython.display import display, HTML\n",
        "\n",
        "\n",
        "def config_line_predicate(l):\n",
        "  return (\n",
        "      'ExperimentConfig.camera_type' not in l\n",
        "      and 'preload_data' not in l\n",
        "      # and 'metadata_at_density' not in l\n",
        "      # and 'hyper_grad_loss_weight' not in l\n",
        "    )\n",
        "\n",
        "\n",
        "print(config_path)\n",
        "\n",
        "# Read the experiment's gin config, dropping the lines rejected by\n",
        "# `config_line_predicate`, and parse what is left.\n",
        "with config_path.open('rt') as f:\n",
        "  lines = filter(config_line_predicate, f.readlines())\n",
        "  gin_config = ''.join(lines)\n",
        "\n",
        "gin.parse_config(gin_config)\n",
        "\n",
        "exp_config = configs.ExperimentConfig()\n",
        "\n",
        "# Explicit keyword arguments take the place of the configured values here.\n",
        "train_config = configs.TrainConfig(\n",
        "    batch_size=2048,\n",
        "    hyper_sheet_alpha_schedule=None,\n",
        ")\n",
        "eval_config = configs.EvalConfig(\n",
        "    chunk=4096,\n",
        ")\n",
        "# Model built with placeholder arguments; the datasource cell only reads its\n",
        "# configuration attributes.\n",
        "dummy_model = models.NerfModel({}, 0, 0)\n",
        "\n",
        "# `Markdown` displays its argument itself and returns None, so it must not be\n",
        "# wrapped in `display(HTML(...))`.\n",
        "Markdown(gin.config.markdown(gin.config_str()))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q6nu1E1s8Heo"
      },
      "source": [
        "# Build the datasource class selected by the experiment config for `data_dir`.\n",
        "datasource = exp_config.datasource_cls(\n",
        "  data_dir=data_dir,\n",
        "  image_scale=exp_config.image_scale,\n",
        "  random_seed=exp_config.random_seed,\n",
        "  # Enable metadata based on model needs.\n",
        "  use_warp_id=dummy_model.use_warp,\n",
        "  # Appearance ids are needed when either the NeRF or the hyper embedding is\n",
        "  # keyed by 'appearance'.\n",
        "  use_appearance_id=(\n",
        "      dummy_model.nerf_embed_key == 'appearance'\n",
        "      or dummy_model.hyper_embed_key == 'appearance'),\n",
        "  use_camera_id=dummy_model.nerf_embed_key == 'camera',\n",
        "  use_time=dummy_model.warp_embed_key == 'time')\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "leESeBqh0RSQ"
      },
      "source": [
        "# @title Load model\n",
        "reload()\n",
        "\n",
        "devices = jax.devices()\n",
        "# `key` is also reused by the point-sampling cell further down.\n",
        "rng, key = random.split(rng)\n",
        "params = {}\n",
        "model, params['model'] = models.construct_nerf(\n",
        "    key,\n",
        "    batch_size=train_config.batch_size,\n",
        "    embeddings_dict=datasource.embeddings_dict,\n",
        "    near=datasource.near,\n",
        "    far=datasource.far)\n",
        "\n",
        "# The optimizer is created with a zero learning rate.\n",
        "# NOTE(review): presumably it only serves as the container the checkpoint is\n",
        "# restored into, since this notebook does not train -- confirm.\n",
        "optimizer_def = optim.Adam(0.0)\n",
        "if train_config.use_weight_norm:\n",
        "  optimizer_def = optim.WeightNorm(optimizer_def)\n",
        "optimizer = optimizer_def.create(params)\n",
        "state = model_utils.TrainState(\n",
        "    optimizer=optimizer,\n",
        "    warp_alpha=0.0)\n",
        "scalar_params = training.ScalarParams(\n",
        "    learning_rate=0.0,\n",
        "    elastic_loss_weight=0.0,\n",
        "    background_loss_weight=train_config.background_loss_weight)\n",
        "# First try the checkpoint layout that stores 'optimizer' and 'warp_alpha'\n",
        "# entries; a KeyError triggers the legacy path below.\n",
        "try:\n",
        "  state_dict = checkpoints.restore_checkpoint(checkpoint_dir, None)\n",
        "  state = state.replace(\n",
        "      optimizer=flax.serialization.from_state_dict(state.optimizer, state_dict['optimizer']),\n",
        "      warp_alpha=state_dict['warp_alpha'])\n",
        "except KeyError:\n",
        "  # Load legacy checkpoints.\n",
        "  # These are restored into an optimizer over the bare model params and then\n",
        "  # re-wrapped so the target again has the {'model': ...} layout.\n",
        "  optimizer = optimizer_def.create(params['model'])\n",
        "  state = model_utils.TrainState(optimizer=optimizer)\n",
        "  state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n",
        "  state = state.replace(optimizer=state.optimizer.replace(target={'model': state.optimizer.target}))\n",
        "step = state.optimizer.state.step + 1\n",
        "# Replicate the restored state across devices for the pmapped render function.\n",
        "state = jax_utils.replicate(state, devices=devices)\n",
        "# The freshly initialised params are no longer needed; later cells rebind\n",
        "# `params` from the restored state.\n",
        "del params"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YgaryK1QVQaF"
      },
      "source": [
        "# @title Render function.\n",
        "import functools\n",
        "reload()\n",
        "\n",
        "# Toggles read by `_model_fn` below (Colab form fields).\n",
        "use_warp = True # @param{type: 'boolean'}\n",
        "use_points = False # @param{type: 'boolean'}\n",
        "\n",
        "# Single-device copy of the restored parameters, used by the latent-code\n",
        "# helpers.\n",
        "params = jax_utils.unreplicate(state.optimizer.target)\n",
        "\n",
        "\n",
        "def _model_fn(key_0, key_1, params, rays_dict, warp_extras):\n",
        "  \"\"\"Applies the model to one shard of rays and gathers the outputs.\n",
        "\n",
        "  Args:\n",
        "    key_0: PRNG key handed to the model as the 'coarse' rng.\n",
        "    key_1: PRNG key handed to the model as the 'fine' rng.\n",
        "    params: the model parameters.\n",
        "    rays_dict: the rays to render, passed as first model input.\n",
        "    warp_extras: extra inputs passed as second model input.\n",
        "\n",
        "  Returns:\n",
        "    The model output all-gathered over the 'batch' axis.\n",
        "  \"\"\"\n",
        "  variables = {'params': params}\n",
        "  rngs = {'coarse': key_0, 'fine': key_1}\n",
        "  # `use_points` and `use_warp` are the notebook-level toggles defined above.\n",
        "  out = model.apply(\n",
        "      variables,\n",
        "      rays_dict,\n",
        "      warp_extras,\n",
        "      rngs=rngs,\n",
        "      mutable=False,\n",
        "      metadata_encoded=True,\n",
        "      return_points=use_points,\n",
        "      return_weights=use_points,\n",
        "      use_warp=use_warp)\n",
        "  return jax.lax.all_gather(out, axis_name='batch')\n",
        "\n",
        "# Parallel version of `_model_fn`: every argument is mapped over its leading\n",
        "# axis and the mapped axis is named 'batch' to match the all_gather inside.\n",
        "pmodel_fn = jax.pmap(\n",
        "    # Note rng_keys are useless in eval mode since there's no randomness.\n",
        "    _model_fn,\n",
        "    # key0, key1, params, rays_dict, warp_extras\n",
        "    in_axes=(0, 0, 0, 0, 0),\n",
        "    devices=devices,\n",
        "    donate_argnums=(3,),  # Donate the 'rays' argument.\n",
        "    axis_name='batch',\n",
        ")\n",
        "\n",
        "# Image renderer with the model function, device count and chunk size bound;\n",
        "# later cells call it as `render_fn(state, batch, rng=rng)`.\n",
        "render_fn = functools.partial(evaluation.render_image,\n",
        "                              model_fn=pmodel_fn,\n",
        "                              device_count=len(devices),\n",
        "                              chunk=8192)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a2jkqa5EGiZL"
      },
      "source": [
        "# @title Latent code utils\n",
        "\n",
        "def get_hyper_code(params, item_id):\n",
        "  \"\"\"Looks up the encoded hyper embedding of a dataset item.\n",
        "\n",
        "  Args:\n",
        "    params: dict holding the model parameters under the 'model' key.\n",
        "    item_id: id of the dataset item.\n",
        "\n",
        "  Returns:\n",
        "    The result of `model.encode_hyper_embed` for the item's metadata.\n",
        "  \"\"\"\n",
        "  # NOTE(review): the appearance id is used for both the 'warp' and the\n",
        "  # 'appearance' entry -- confirm this is intended rather than get_warp_id.\n",
        "  embed_id = datasource.get_appearance_id(item_id)\n",
        "  metadata = {\n",
        "      name: jnp.array([embed_id], jnp.uint32)\n",
        "      for name in ('warp', 'appearance')\n",
        "  }\n",
        "  variables = {'params': params['model']}\n",
        "  return model.apply(variables, metadata, method=model.encode_hyper_embed)\n",
        "\n",
        "\n",
        "def get_appearance_code(params, item_id):\n",
        "  \"\"\"Looks up the encoded appearance (NeRF) embedding of a dataset item.\n",
        "\n",
        "  Args:\n",
        "    params: dict holding the model parameters under the 'model' key.\n",
        "    item_id: id of the dataset item.\n",
        "\n",
        "  Returns:\n",
        "    The result of `model.encode_nerf_embed` for the item's metadata.\n",
        "  \"\"\"\n",
        "  embed_id = datasource.get_appearance_id(item_id)\n",
        "  metadata = {'appearance': jnp.array([embed_id], jnp.uint32)}\n",
        "  variables = {'params': params['model']}\n",
        "  return model.apply(variables, metadata, method=model.encode_nerf_embed)\n",
        "\n",
        "\n",
        "def get_warp_code(params, item_id):\n",
        "  \"\"\"Looks up the encoded warp embedding of a dataset item.\n",
        "\n",
        "  Args:\n",
        "    params: dict holding the model parameters under the 'model' key.\n",
        "    item_id: id of the dataset item.\n",
        "\n",
        "  Returns:\n",
        "    The result of `model.encode_warp_embed` for the item's metadata.\n",
        "  \"\"\"\n",
        "  embed_id = datasource.get_warp_id(item_id)\n",
        "  metadata = {'warp': jnp.array([embed_id], jnp.uint32)}\n",
        "  variables = {'params': params['model']}\n",
        "  return model.apply(variables, metadata, method=model.encode_warp_embed)\n",
        "\n",
        "def get_codes(item_id):\n",
        "  appearance_code = None\n",
        "  if model.use_rgb_condition:\n",
        "    appearance_code = get_appearance_code(params, item_id)\n",
        "  \n",
        "  warp_codes = None\n",
        "  if model.use_warp:\n",
        "    warp_code = get_warp_code(params, item_id)\n",
        " \n",
        "  hyper_codes = None\n",
        "  if model.has_hyper:\n",
        "    hyper_code = get_hyper_code(params, item_id)\n",
        "  \n",
        "  return appearance_code, warp_code, hyper_code\n",
        "\n",
        "\n",
        "def make_batch(camera, appearance_code=None, warp_code=None, hyper_code=None):\n",
        "  batch = datasets.camera_to_rays(camera)\n",
        "  batch_shape = batch['origins'][..., 0].shape\n",
        "  metadata = {}\n",
        "  if appearance_code is not None:\n",
        "      appearance_code = appearance_code.squeeze(0)\n",
        "      metadata['encoded_nerf'] = jnp.broadcast_to(\n",
        "          appearance_code[None, None, :], (*batch_shape, appearance_code.shape[-1]))\n",
        "  if warp_code is not None:\n",
        "    metadata['encoded_warp'] = jnp.broadcast_to(\n",
        "        warp_code[None, None, :], (*batch_shape, warp_code.shape[-1]))\n",
        "  batch['metadata'] = metadata\n",
        "\n",
        "  if hyper_code is not None:\n",
        "    batch['metadata']['encoded_hyper'] = jnp.broadcast_to(\n",
        "        hyper_code[None, None, :], (*batch_shape, hyper_code.shape[-1]))\n",
        "  \n",
        "  return batch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-K19Z2Tsmlxe"
      },
      "source": [
        "# @title Manual crop\n",
        "\n",
        "render_scale = 0.5\n",
        "target_rgb = image_utils.downsample_image(datasource.load_rgb(datasource.train_ids[0]), int(1/render_scale))\n",
        "# Crop margins; they are also passed to `camera.crop_image_domain` when the\n",
        "# hyper grid is rendered.\n",
        "top, bottom, left, right = 2 * np.array([89, 75, 32, 26])  # K\n",
        "\n",
        "# top, bottom, left, right = 2 * np.array([60, 70, 14, 10])  # R\n",
        "# top, bottom, left, right = 0, 30, 68, 68  # lemon\n",
        "# top, bottom, left, right = 40, 100, 2, 40  # slice-banana\n",
        "# `-margin or None` keeps the full extent when a bottom/right margin is 0;\n",
        "# a plain `[top:-0]` slice would produce an empty image.\n",
        "target_rgb = target_rgb[top:-bottom or None, left:-right or None]\n",
        "print(target_rgb.shape)\n",
        "media.show_image(target_rgb)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Zj3uIsjJfKO"
      },
      "source": [
        "## Hyper grid."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c92t15kuKanf"
      },
      "source": [
        "# @title Sample points and metadata\n",
        "\n",
        "item_id = datasource.train_ids[0]\n",
        "camera = datasource.load_camera(item_id).scale(render_scale)\n",
        "# NOTE(review): the return value is discarded here while the grid-rendering\n",
        "# cell assigns it -- confirm whether this call has any effect.\n",
        "camera.crop_image_domain()\n",
        "batch = make_batch(camera, *get_codes(item_id))\n",
        "origins = batch['origins']\n",
        "directions = batch['directions']\n",
        "metadata = batch['metadata']\n",
        "z_vals, points = model_utils.sample_along_rays(\n",
        "    rng, origins[None, ...], directions[None, ...],\n",
        "    model.num_coarse_samples,\n",
        "    model.near,\n",
        "    model.far,\n",
        "    model.use_stratified_sampling,\n",
        "    model.use_linear_disparity)\n",
        "points = points.reshape((-1, 3))\n",
        "# Keep a random subset of the sampled points.\n",
        "points = random.permutation(rng, points)[:8096*4]\n",
        "print(points.shape)\n",
        "\n",
        "# Draw a random warp embedding id for every point and encode it.\n",
        "warp_metadata = random.randint(\n",
        "    key, (points.shape[0], 1), 0, model.num_warp_embeds, dtype=jnp.uint32)\n",
        "warp_embed = model.apply({'params': params['model']},\n",
        "                          {model.warp_embed_key: warp_metadata},\n",
        "                          method=model.encode_warp_embed)\n",
        "if model.has_hyper_embed:\n",
        "  # Same `key` as above, so the draws are not independent of the warp ids.\n",
        "  hyper_metadata = random.randint(\n",
        "      key, (points.shape[0], 1), 0, model.num_hyper_embeds, dtype=jnp.uint32)\n",
        "  hyper_embed_key = (model.warp_embed_key if model.hyper_use_warp_embed\n",
        "                      else model.hyper_embed_key)\n",
        "  hyper_embed = model.apply({'params': params['model']},\n",
        "                            {hyper_embed_key: hyper_metadata},\n",
        "                            method=model.encode_hyper_embed)\n",
        "else:\n",
        "  hyper_embed = None\n",
        "\n",
        "map_fn = functools.partial(model.apply, method=model.map_points)\n",
        "# Without a hyper embedding `hyper_embed` is None and cannot be indexed, so\n",
        "# None is passed through instead.\n",
        "# NOTE(review): presumably `map_points` accepts None in that case -- confirm.\n",
        "warped_points, _ = map_fn(\n",
        "    {'params': params['model']},\n",
        "    points[:, None],\n",
        "    None if hyper_embed is None else hyper_embed[:, None],\n",
        "    warp_embed[:, None],\n",
        "    jax_utils.unreplicate(state.extra_params))\n",
        "# Everything from column 3 onwards is kept as the hyper coordinates.\n",
        "hyper_points = np.array(warped_points[..., 3:].squeeze())\n",
        "print(hyper_points.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MvdO4IoUhSu-"
      },
      "source": [
        "# Bounds of the first two hyper coordinates: the 20th percentile is the\n",
        "# lower bound and the 99th percentile the upper bound (instead of the raw\n",
        "# min/max of `hyper_points`).\n",
        "(umin, vmin), (umax, vmax) = np.percentile(\n",
        "    hyper_points[..., :2], [20, 99], axis=0)\n",
        "umin, vmin, umax, vmax"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l0k0Ciz_Wuk6"
      },
      "source": [
        "# Side length of the grid; n * n hyper coordinates are rendered below.\n",
        "n = 7\n",
        "# Regular grid spanning [umin, umax] x [vmin, vmax]; the last axis of\n",
        "# `hyper_grid` holds the (u, v) pair.\n",
        "uu, vv = np.meshgrid(np.linspace(umin, umax, n), np.linspace(vmin, vmax, n))\n",
        "hyper_grid = np.stack([uu, vv], axis=-1)\n",
        "hyper_grid[0, 0], hyper_grid[-1, -1]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2Rv8WW3kGibq"
      },
      "source": [
        "import itertools\n",
        "import gc\n",
        "gc.collect()\n",
        "\n",
        "# One dict with the 'rgb' and 'depth' render is collected per grid cell.\n",
        "grid_frames = []\n",
        "\n",
        "# Render from the first training camera, cropped with the margins chosen in\n",
        "# the manual-crop cell.\n",
        "camera = datasource.load_camera(item_id).scale(render_scale)\n",
        "camera = camera.crop_image_domain(left, right, top, bottom)\n",
        "\n",
        "batch = make_batch(camera, *get_codes(item_id))\n",
        "batch_shape = batch['origins'][..., 0].shape\n",
        "for i, j in itertools.product(range(n), range(n)):\n",
        "  # Use the same (u, v) grid value for every ray of the batch.\n",
        "  hyper_point = jnp.array(hyper_grid[i, j])\n",
        "  # hyper_point = jnp.concatenate([hyper_point, jnp.zeros((6,))])\n",
        "  hyper_point = jnp.broadcast_to(\n",
        "          hyper_point[None, None, :], \n",
        "          (*batch_shape, hyper_point.shape[-1]))\n",
        "  batch['metadata']['hyper_point'] = hyper_point\n",
        "  \n",
        "  render = render_fn(state, batch, rng=rng)\n",
        "  # Copy the needed outputs to numpy before the render dict is deleted.\n",
        "  pred_rgb = np.array(render['rgb'])\n",
        "  pred_depth_med = np.array(render['med_depth'])\n",
        "  # The inverse of the median depth is what gets colorized.\n",
        "  pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n",
        "  del render\n",
        "  \n",
        "  media.show_images([pred_rgb, pred_depth_viz])\n",
        "  grid_frames.append({\n",
        "      'rgb': pred_rgb,\n",
        "      'depth': pred_depth_med,\n",
        "  })\n",
        "\n",
        "# Overview of all RGB renders, n per row.\n",
        "media.show_images([f['rgb'] for f in grid_frames], columns=n)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nejSHA2WyAQD"
      },
      "source": [
        "from matplotlib import pyplot as plt\n",
        "from mpl_toolkits.axes_grid1 import ImageGrid\n",
        "import numpy as np\n",
        "\n",
        "# Lay the rendered grid frames out as one n x n matplotlib figure.\n",
        "fig = plt.figure(figsize=(24., 24.))\n",
        "grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
        "                 nrows_ncols=(n, n),  # creates an n x n grid of axes\n",
        "                 axes_pad=0.1,  # pad between axes in inch.\n",
        "                 )\n",
        "\n",
        "images = [f['rgb'] for f in grid_frames]\n",
        "for ax, im in zip(grid, images):\n",
        "    # Iterating over the grid returns the Axes.\n",
        "    ax.imshow(im)\n",
        "    # Hide all axis decorations so only the images remain.\n",
        "    ax.set_axis_off()\n",
        "    ax.margins(x=0, y=0)\n",
        "    ax.get_xaxis().set_visible(False)\n",
        "    ax.get_yaxis().set_visible(False)\n",
        "    ax.set_aspect('equal')\n",
        "fig.tight_layout(pad=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4FC-MsLz5dAM"
      },
      "source": [
        "from scipy import interpolate\n",
        "\n",
        "# Number of points sampled along the interpolated path.\n",
        "num_samples = 200\n",
        "# points = np.random.uniform(0, 1, size=(10, 2))\n",
        "# NOTE: this rebinds the notebook-level `rng` that later render calls use.\n",
        "rng = random.PRNGKey(3)\n",
        "# points = random.uniform(rng, (20, 2))\n",
        "# Control points of the path in the unit square; the first and last point\n",
        "# coincide, as the periodic boundary condition below requires.\n",
        "points = np.array([\n",
        "  [0.2, 0.1],\n",
        "  [0.2, 0.8],\n",
        "  [0.8, 0.8],\n",
        "  [0.8, 0.1],\n",
        "  [0.5, 0.1],\n",
        "  [0.2, 0.4],\n",
        "  [0.5, 0.7],\n",
        "  [0.8, 0.7],\n",
        "  [0.6, 0.2],\n",
        "  [0.2, 0.1],\n",
        "])\n",
        "# Parameterise the control points by their index and sample the spline\n",
        "# uniformly over that parameter.\n",
        "t = np.arange(len(points))\n",
        "xs = np.linspace(0, len(points) - 1, num_samples)\n",
        "cs = interpolate.CubicSpline(t, points, bc_type='periodic')\n",
        "\n",
        "interp_points = cs(xs).astype(np.float32)\n",
        "# Plot the sampled path (small dots) over the control points.\n",
        "fig, ax = plt.subplots()\n",
        "ax.scatter(interp_points[:, 0], interp_points[:, 1], s=2)\n",
        "ax.scatter(points[:, 0], points[:, 1])\n",
        "ax.set_aspect('equal')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lQUcIumi7Ts8"
      },
      "source": [
        "# Map the unit-square path onto the [umin, umax] x [vmin, vmax] range of the\n",
        "# hyper coordinates.\n",
        "interp_hyper_points = np.stack(\n",
        "    [\n",
        "        (umax - umin) * interp_points[:, 0] + umin,\n",
        "        (vmax - vmin) * interp_points[:, 1] + vmin,\n",
        "    ],\n",
        "    axis=-1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A7NYM0Rs_Tie"
      },
      "source": [
        "## Make Orbit Cameras"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fuZGlM90_UZ-"
      },
      "source": [
        "# Load the camera of every item in the datasource.\n",
        "ref_cameras = utils.parallel_map(datasource.load_camera, datasource.all_ids)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eMsYdVEDXcFA"
      },
      "source": [
        "## Select Keyframes and Interpolate Codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "71g9fqINo8sC"
      },
      "source": [
        "# @title Show training frames to choose IDs\n",
        "# Every 4th training frame, downsampled by 1 / render_scale, titled by id.\n",
        "target_ids = datasource.train_ids[::4]\n",
        "target_rgbs = utils.parallel_map(\n",
        "    lambda i: image_utils.downsample_image(datasource.load_rgb(i), int(1/render_scale)), \n",
        "    target_ids)\n",
        "media.show_images(target_rgbs, titles=target_ids, columns=20)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oSZraLc2Y7QN"
      },
      "source": [
        "## Render"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GU9PaEHmX_QU"
      },
      "source": [
        "# @title Latent code functions\n",
        "\n",
        "reload()\n",
        "\n",
        "\n",
        "def get_hyper_code(params, item_id):\n",
        "  \"\"\"Looks up the encoded hyper embedding of a dataset item.\n",
        "\n",
        "  Args:\n",
        "    params: dict holding the model parameters under the 'model' key.\n",
        "    item_id: id of the dataset item.\n",
        "\n",
        "  Returns:\n",
        "    The result of `model.encode_hyper_embed` for the item's metadata.\n",
        "  \"\"\"\n",
        "  # NOTE(review): the appearance id is used for both the 'warp' and the\n",
        "  # 'appearance' entry -- confirm this is intended rather than get_warp_id.\n",
        "  embed_id = datasource.get_appearance_id(item_id)\n",
        "  metadata = {\n",
        "      name: jnp.array([embed_id], jnp.uint32)\n",
        "      for name in ('warp', 'appearance')\n",
        "  }\n",
        "  variables = {'params': params['model']}\n",
        "  return model.apply(variables, metadata, method=model.encode_hyper_embed)\n",
        "\n",
        "\n",
        "def get_appearance_code(params, item_id):\n",
        "  \"\"\"Looks up the encoded appearance (NeRF) embedding of a dataset item.\n",
        "\n",
        "  Args:\n",
        "    params: dict holding the model parameters under the 'model' key.\n",
        "    item_id: id of the dataset item.\n",
        "\n",
        "  Returns:\n",
        "    The result of `model.encode_nerf_embed` for the item's metadata.\n",
        "  \"\"\"\n",
        "  embed_id = datasource.get_appearance_id(item_id)\n",
        "  metadata = {'appearance': jnp.array([embed_id], jnp.uint32)}\n",
        "  variables = {'params': params['model']}\n",
        "  return model.apply(variables, metadata, method=model.encode_nerf_embed)\n",
        "\n",
        "\n",
        "def get_warp_code(params, item_id):\n",
        "  \"\"\"Looks up the encoded warp embedding of a dataset item.\n",
        "\n",
        "  Args:\n",
        "    params: dict holding the model parameters under the 'model' key.\n",
        "    item_id: id of the dataset item.\n",
        "\n",
        "  Returns:\n",
        "    The result of `model.encode_warp_embed` for the item's metadata.\n",
        "  \"\"\"\n",
        "  embed_id = datasource.get_warp_id(item_id)\n",
        "  metadata = {'warp': jnp.array([embed_id], jnp.uint32)}\n",
        "  variables = {'params': params['model']}\n",
        "  return model.apply(variables, metadata, method=model.encode_warp_embed)\n",
        "\n",
        "\n",
        "# Smoke test: print the codes of the first training item for each embedding\n",
        "# the model uses.\n",
        "params = jax_utils.unreplicate(state.optimizer.target)\n",
        "if model.use_rgb_condition:\n",
        "  test_appearance_code = get_appearance_code(params, datasource.train_ids[0])\n",
        "  print('appearance code:', test_appearance_code)\n",
        "\n",
        "if model.use_warp:\n",
        "  test_warp_code = get_warp_code(params, datasource.train_ids[0])\n",
        "  print('warp code:', test_warp_code)\n",
        "\n",
        "if model.has_hyper:\n",
        "  test_hyper_code = get_hyper_code(params, datasource.train_ids[0])\n",
        "  print('hyper code:', test_hyper_code)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3uFsdK2naA4D"
      },
      "source": [
        "# @title Render function.\n",
        "import functools\n",
        "reload()\n",
        "\n",
        "# Toggles read by `_model_fn` below (Colab form fields).\n",
        "use_warp = True # @param{type: 'boolean'}\n",
        "use_points = False # @param{type: 'boolean'}\n",
        "\n",
        "# Single-device copy of the restored parameters, used by the latent-code\n",
        "# helpers.\n",
        "params = jax_utils.unreplicate(state.optimizer.target)\n",
        "\n",
        "\n",
        "def _model_fn(key_0, key_1, params, rays_dict, warp_extras):\n",
        "  \"\"\"Applies the model to one shard of rays and gathers the outputs.\n",
        "\n",
        "  Args:\n",
        "    key_0: PRNG key handed to the model as the 'coarse' rng.\n",
        "    key_1: PRNG key handed to the model as the 'fine' rng.\n",
        "    params: the model parameters.\n",
        "    rays_dict: the rays to render, passed as first model input.\n",
        "    warp_extras: extra inputs passed as second model input.\n",
        "\n",
        "  Returns:\n",
        "    The model output all-gathered over the 'batch' axis.\n",
        "  \"\"\"\n",
        "  variables = {'params': params}\n",
        "  rngs = {'coarse': key_0, 'fine': key_1}\n",
        "  # `use_points` and `use_warp` are the notebook-level toggles defined above.\n",
        "  out = model.apply(\n",
        "      variables,\n",
        "      rays_dict,\n",
        "      warp_extras,\n",
        "      rngs=rngs,\n",
        "      mutable=False,\n",
        "      metadata_encoded=True,\n",
        "      return_points=use_points,\n",
        "      return_weights=use_points,\n",
        "      use_warp=use_warp)\n",
        "  return jax.lax.all_gather(out, axis_name='batch')\n",
        "\n",
        "# Parallel version of `_model_fn`: every argument is mapped over its leading\n",
        "# axis and the mapped axis is named 'batch' to match the all_gather inside.\n",
        "pmodel_fn = jax.pmap(\n",
        "    # Note rng_keys are useless in eval mode since there's no randomness.\n",
        "    _model_fn,\n",
        "    # key0, key1, params, rays_dict, warp_extras\n",
        "    in_axes=(0, 0, 0, 0, 0),\n",
        "    devices=devices,\n",
        "    donate_argnums=(3,),  # Donate the 'rays' argument.\n",
        "    axis_name='batch',\n",
        ")\n",
        "\n",
        "# Image renderer with the model function, device count and chunk size bound;\n",
        "# later cells call it as `render_fn(state, batch, rng=rng)`.\n",
        "render_fn = functools.partial(evaluation.render_image,\n",
        "                              model_fn=pmodel_fn,\n",
        "                              device_count=len(devices),\n",
        "                              chunk=8192)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fK7FZ1q3GLyi"
      },
      "source": [
        "# Items whose latent codes are rendered below. The other ids are kept as\n",
        "# commented alternatives; previously '001428' was assigned and immediately\n",
        "# overwritten.\n",
        "# item_ids = ['001428']\n",
        "item_ids = ['000082']\n",
        "# item_ids = ['000457']\n",
        "# item_ids = ['000429']\n",
        "# item_ids = ['000610']  # ricardo\n",
        "# NOTE: rebinds the `render_scale` set in the manual-crop cell.\n",
        "render_scale = 1.0\n",
        "\n",
        "media.show_images([datasource.load_rgb(x) for x in item_ids], titles=item_ids)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MquILATfGL2x"
      },
      "source": [
        "\n",
        "import gc\n",
        "gc.collect()\n",
        "\n",
        "# Camera every item is rendered from. The alternatives are kept as comments;\n",
        "# previously the first one was loaded and immediately overwritten.\n",
        "# base_camera = datasource.load_camera(item_ids[0]).scale(render_scale)\n",
        "# base_camera = datasource.load_camera('000037').scale(render_scale)\n",
        "base_camera = datasource.load_camera('000389').scale(render_scale)\n",
        "# orbit_cameras = [c.scale(render_scale) for c in make_orbit_cameras(360)] \n",
        "# base_camera = orbit_cameras[270]\n",
        "\n",
        "out_frames = []\n",
        "for item_id in item_ids:\n",
        "  camera = base_camera\n",
        "  print(f'>>> Rendering ID {item_id} <<<')\n",
        "  # The appearance code is passed as returned: `make_batch` removes its\n",
        "  # leading axis itself, so squeezing it here as well would break that step.\n",
        "  appearance_code = get_appearance_code(params, item_id) if model.use_nerf_embed else None\n",
        "  warp_code = get_warp_code(params, item_id).squeeze() if model.use_warp else None\n",
        "  hyper_code = get_hyper_code(params, item_id).squeeze() if model.has_hyper_embed else None\n",
        "  batch = make_batch(camera, appearance_code, warp_code, hyper_code)\n",
        "\n",
        "  render = render_fn(state, batch, rng=rng)\n",
        "  pred_rgb = np.array(render['rgb'])\n",
        "  pred_depth_med = np.array(render['med_depth'])\n",
        "  # The inverse of the median depth is what gets colorized.\n",
        "  pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n",
        "\n",
        "  media.show_images([pred_rgb, pred_depth_viz])\n",
        "  # NOTE(review): 'med_points' is read although `use_points` is False in the\n",
        "  # render-function cell -- confirm the renderer still returns it.\n",
        "  out_frames.append({\n",
        "      'rgb': pred_rgb,\n",
        "      'depth': pred_depth_med,\n",
        "      'med_points': np.array(render['med_points']),\n",
        "  })\n",
        "  del batch, render"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_8t2T_f4IKFB"
      },
      "source": [
        "from skimage.color import hsv2rgb\n",
        "\n",
        "def sinebow(h):\n",
        "  f = lambda x : np.sin(np.pi * x)**2\n",
        "  return np.stack([f(3/6-h), f(5/6-h), f(7/6-h)], -1)\n",
        "\n",
        "\n",
        "def colorize_flow(u, v, phase=0, freq=1):\n",
        "  coords = np.stack([u, v], axis=-1)\n",
        "  mag = np.linalg.norm(coords, axis=-1) / np.sqrt(2)\n",
        "  angle = np.arctan2(-v, -u) / np.pi / (2/freq)\n",
        "  print(angle.min(), angle.max())\n",
        "  # return viz.colorize(np.log(mag+1e-6), cmap='gray')\n",
        "  colorwheel = sinebow(angle + phase/360*np.pi)\n",
        "  # brightness = mag[..., None] ** 1.414\n",
        "  brightness = mag[..., None] ** 1.0\n",
        "  # brightness = (25 * np.cbrt(mag[..., None]*100) - 17)/100\n",
        "  # brightness = (((mag[..., None]*100 + 17)/25)**3)/100\n",
        "  bg = np.ones_like(colorwheel) * 0.5\n",
        "  # bg = np.ones_like(colorwheel) * 0.0\n",
        "  return colorwheel * brightness + bg * (1.0 - brightness)\n",
        "\n",
        "  \n",
        "def visualize_hyper_points(frame):\n",
        "  hyper_points = frame['med_points'].squeeze()[..., 3:]\n",
        "  uu = (hyper_points[..., 0] - umin) / (umax - umin)\n",
        "  vv = (hyper_points[..., 1] - vmin) / (vmax - vmin)\n",
        "  normalized_hyper_points = np.stack([uu, vv], axis=-1)\n",
        "  normalized_hyper_points = (normalized_hyper_points - 0.5) * 2.0\n",
        "  print(normalized_hyper_points.min(), normalized_hyper_points.max())\n",
        "  return colorize_flow(normalized_hyper_points[..., 0], normalized_hyper_points[..., 1])\n",
        "\n",
        "\n",
        "uu = np.linspace(-1, 1, 256)\n",
        "vv = np.linspace(-1, 1, 256)\n",
        "uu, vv = np.meshgrid(uu, vv)\n",
        "\n",
        "media.show_image(colorize_flow(uu, vv))\n",
        "\n",
        "\n",
        "# media.show_image(visualize_hyper_points(out_frames[0]))\n",
        "for frame in out_frames:\n",
        "  pred_rgb = frame['rgb']\n",
        "  pred_depth = frame['depth']\n",
        "  # depth_viz = viz.colorize(1/pred_depth.squeeze(), cmin=1.6, cmax=3.0, cmap='turbo', invert=False)\n",
        "  depth_viz = viz.colorize(1/pred_depth.squeeze(), cmin=1.6, cmax=2.3, cmap='turbo', invert=False)\n",
        "  hyper_viz = visualize_hyper_points(out_frames[0])\n",
        "  media.show_images([pred_rgb, depth_viz, hyper_viz])\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Zp6vJhCEKCSn"
      },
      "source": [
        "uu = np.linspace(-1, 1, 1024)\n",
        "vv = np.linspace(-1, 1, 1024)\n",
        "uu, vv = np.meshgrid(uu, vv)\n",
        "\n",
        "media.show_image(colorize_flow(uu, vv))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GxYcw_ZiirkW"
      },
      "source": [
        "from PIL import Image, ImageDraw\n",
        "\n",
        "\n",
        "def crop_circle(img, width=3, color=(0, 0, 0)):\n",
        "  img = Image.fromarray(image_utils.image_to_uint8(img))\n",
        "  h,w=img.size\n",
        "  \n",
        "  # Create same size alpha layer with circle\n",
        "  alpha = Image.new('L', img.size,0)\n",
        "  draw = ImageDraw.Draw(alpha)\n",
        "  draw.pieslice([0,0,h,w],0,360,fill=255)\n",
        "  # Convert alpha Image to numpy array\n",
        "  npAlpha=np.array(alpha)\n",
        "\n",
        "  draw = ImageDraw.Draw(img)\n",
        "  draw.arc([0, 0, h, w], 0, 360, fill=tuple(color), width=width)\n",
        "  npImage=np.array(img)\n",
        "  \n",
        "  # Add alpha layer to RGB\n",
        "  npImage=np.dstack((npImage,npAlpha))\n",
        "  return image_utils.image_to_float32(npImage)\n",
        "\n",
        "\n",
        "media.show_image(crop_circle(images[0]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C0v2AYnlIXnr"
      },
      "source": [
        "from matplotlib import pyplot as plt\n",
        "from mpl_toolkits.axes_grid1 import ImageGrid\n",
        "import numpy as np\n",
        "\n",
        "fig = plt.figure(figsize=(24., 24.))\n",
        "grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
        "                 nrows_ncols=(n, n),  # creates 2x2 grid of axes\n",
        "                 axes_pad=0.1,  # pad between axes in inch.\n",
        "                 )\n",
        "\n",
        "uu = np.linspace(-1, 1, 7)\n",
        "vv = np.linspace(-1, 1, 7)\n",
        "uu, vv = np.meshgrid(uu, vv)\n",
        "grid_colors = image_utils.image_to_uint8(colorize_flow(uu, vv))\n",
        "grid_colors = grid_colors.reshape((-1, 3))\n",
        "\n",
        "images = [f['rgb'] for f in grid_frames]\n",
        "for i, (ax, im) in enumerate(zip(grid, images)):\n",
        "    # Iterating over the grid returns the Axes.\n",
        "    color = tuple(grid_colors[i])\n",
        "    ax.imshow(crop_circle(im, width=14, color=color))\n",
        "    ax.set_axis_off()\n",
        "    ax.margins(x=0, y=0)\n",
        "    ax.get_xaxis().set_visible(False)\n",
        "    ax.get_yaxis().set_visible(False)\n",
        "    ax.set_aspect('equal')\n",
        "fig.tight_layout(pad=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YTMDxXK4Ig3K"
      },
      "source": [
        "(np.array([75, 140, 40, 100], dtype=np.float)*0.9).round()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}